{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7704d3bb"
   },
   "source": [
    "(pallas_tpu_pipelining)="
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "teoJ_fUwlu0l"
   },
   "source": [
    "# TPU Pipelining\n",
    "\n",
    "<!--* freshness: { reviewed: '2024-04-08' } *-->"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gAJDZh1gBh-h"
   },
   "source": [
    "This guide serves as a reference for TPU-specific pipelining concerns.\n",
    "We'll review the memory hierarchy and compute units on TPUs, and TPU-specific features of the pipelining API. For a more general-purpose overview of pipelining, see {ref}`pallas_software_pipelining`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "executionInfo": {
     "elapsed": 54,
     "status": "ok",
     "timestamp": 1744908474512,
     "user": {
      "displayName": "Justin Fu",
      "userId": "17543197034567316452"
     },
     "user_tz": 420
    },
    "id": "ejAVO6ikUUuF"
   },
   "outputs": [],
   "source": [
    "#@title Imports\n",
    "\n",
    "import jax\n",
    "from jax.experimental import pallas as pl\n",
    "from jax.experimental.pallas import tpu as pltpu\n",
    "import jax.numpy as jnp\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0e212a5e"
   },
   "source": [
    "(tpu_and_its_memory_spaces)=\n",
    "\n",
    "## TPU and its memory spaces"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NnWW9GV4kW6P"
   },
   "source": [
    "A TPU and its TensorCore consist of memory spaces (where arrays can reside),\n",
    "registers (which temporarily store scalar and array values) and compute units\n",
    "(that do computation with values in registers).\n",
    "Below is a diagram of a TPU in which `x` and `y` are arrays that live in\n",
    "high-bandwidth memory (HBM):\n",
    "\n",
    "![TPU Memory Space Cartoon.png]()\n",
    "\n",
    "Let's talk about the components of this diagram in more detail:\n",
    "\n",
    "* **Memory spaces**: A TPU has high-bandwidth memory (HBM) which is what we\n",
    "  often think of as \"device memory\".\n",
    "  There is also vector memory (VMEM),\n",
    "  a cache meant for storing vector and array values, and scalar memory (SMEM),\n",
    "  a cache designed to store scalar values.\n",
    "* **Registers**: A TensorCore has two main types of registers: vector\n",
    "  registers (VREGs) store array values, and scalar registers (SREGs) store\n",
    "  scalar values.\n",
    "  Values can be loaded into registers from their respective caches (VMEM\n",
    "  for VREGs and SMEM for SREGs).\n",
    "* **Compute units**: A TensorCore has a scalar unit, vector unit (VPU) and\n",
    "  matrix unit (MXU) that can do numerical computation. Each of these compute units can operate asynchronously, but this is managed by the TPU compiler and thus from the programmer's perspective a TPU program is single-threaded.\n",
    "  Compute units operate on values that live in SREGs and VREGs and output\n",
    "  values into those registers as well."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8Tl3wt5Wk3Ek"
   },
   "source": [
    "## TPU-specific Pipelining Features\n",
    "\n",
    "Pallas TPU supports the following platform-specific features."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1jg5WmExk47l"
   },
   "source": [
    "### TPU Memory Spaces\n",
    "\n",
    "Pallas exposes all levels of the TPU memory hierarchy to users. The following table maps from Pallas TPU memory spaces to their standard memory types (DRAM/SRAM):\n",
    "\n",
    "| Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) |\n",
    "| --- | --- | --- |\n",
    "| `pltpu.MemorySpace.ANY` | HBM (usually) or VMEM | DRAM |\n",
    "| `pltpu.MemorySpace.VMEM` | VMEM | SRAM |\n",
    "| `pltpu.MemorySpace.SMEM` | SMEM | SRAM |\n",
    "| `pltpu.MemorySpace.SEMAPHORE` | Semaphore | SRAM |\n",
    "\n",
    "- `MemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified.\n",
    "- `MemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM.\n",
    "- `MemorySpace.ANY` is a hint to the compiler that the memory space is unconstrained. In most cases, XLA will place this buffer in HBM. A buffer assigned to the `ANY` memory space cannot be dereferenced normally using array indexing syntax (e.g. `x[...]`). Instead, we must first copy the values into a VMEM or SMEM buffer using `pltpu.sync_copy` or `pltpu.async_copy`.\n",
    "- `MemorySpace.SEMAPHORE` is used to allocate semaphores for constructing barriers or tracking asynchronous operations. It is also possible to return semaphores from the kernel for building asynchronous kernels - this is an experimental feature; see {ref}`pallas_async` for more details.\n",
    "\n",
    "Pipelining on TPUs is typically done between HBM (DRAM) to VMEM (Vector SRAM). The default behavior for `pallas_call` on TPU is that arguments to `pallas_call` are assumed to live in HBM, and inputs to the user kernel body are stored in VMEM.\n",
    "\n",
    "While not specific to pipelining, it is possible to gain manual control over the memory space of input and output buffers by specifying the `memory_space` argument on a `BlockSpec`. Note that pipelining is not allowed unless the `memory_space` is marked as `VMEM`. Memory spaces can also be used to specify scratch arguments to a kernel via the `scratch_shapes` argument on `pallas_call`. Scratch buffers are persistent across kernel iterations and are useful for storing intermediate results such as partial accumulations and reductions. A scratch buffer must reside in `VMEM`, `SMEM`, or `SEMAPHORE`.\n",
    "\n",
    "As an example for using multiple manual memory space assignments in a kernel, the following program copies a slice of an HBM buffer `x_hbm_ref` into a scratch VMEM buffer `scratch_vmem_ref` before using it for arithmetic and storing the result into an output VMEM buffer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "executionInfo": {
     "elapsed": 65,
     "status": "ok",
     "timestamp": 1744908591430,
     "user": {
      "displayName": "Justin Fu",
      "userId": "17543197034567316452"
     },
     "user_tz": 420
    },
    "id": "zcqz1CA_o50a"
   },
   "outputs": [],
   "source": [
    "def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref):\n",
    "  pltpu.sync_copy(x_hbm_ref.at[0:1], scratch_vmem_ref)\n",
    "  out_vmem_ref[...] = scratch_vmem_ref[...] + 1\n",
    "\n",
    "x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)\n",
    "out = pl.pallas_call(hbm_vmem_kernel,\n",
    "  in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)],\n",
    "  out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),\n",
    "  scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),)\n",
    ")(x)\n",
    "\n",
    "np.testing.assert_allclose(out, x[0:1] + 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiple Buffering\n",
    "\n",
    "Multiple buffering can be specified on a per-argument basis to the pipeline via the `pipeline_mode` option on `pl.BlockSpec`. To do so, pass a `pl.Buffered` object to `pl.BlockSpec` specifying the number of buffers to allocate for this particular argument:\n",
    "\n",
    "```python\n",
    "pl.BlockSpec(\n",
    "  pipeline_mode=pl.Buffered(buffer_count=buffer_count)\n",
    ")\n",
    "```\n",
    "\n",
    "The default buffer count is 2 for all inputs and outputs."
   ]
  },
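  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a concrete sketch, the following requests triple buffering for the input of an elementwise copy kernel. The kernel, shapes, and grid here are arbitrary (and running the `pallas_call` requires a TPU); the `BlockSpec` construction is the piece of interest:\n",
    "\n",
    "```python\n",
    "def copy_kernel(x_ref, o_ref):\n",
    "  o_ref[...] = x_ref[...]\n",
    "\n",
    "# Allocate 3 pipeline buffers for the input instead of the default 2,\n",
    "# allowing an extra fetch to be in flight at any time.\n",
    "triple_buffered_spec = pl.BlockSpec(\n",
    "    (256, 256), lambda i: (i, 0),\n",
    "    pipeline_mode=pl.Buffered(buffer_count=3),\n",
    ")\n",
    "\n",
    "def copy_pipelined(x):\n",
    "  return pl.pallas_call(\n",
    "      copy_kernel,\n",
    "      grid=(4,),\n",
    "      in_specs=[triple_buffered_spec],\n",
    "      out_specs=pl.BlockSpec((256, 256), lambda i: (i, 0)),\n",
    "      out_shape=jax.ShapeDtypeStruct((1024, 256), jnp.float32),\n",
    "  )(x)\n",
    "```"
   ]
  },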
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(pallas_tpu_emit_pipeline)=\n",
    "\n",
    "### pltpu.emit_pipeline\n",
    "\n",
    "`pltpu.emit_pipeline` is a pipelining API implemented in Pallas that allows you to construct pipelines inside of a kernel rather than only on kernel entry. This enables several use-cases beyond what `pl.pallas_call` supports, such as:\n",
    "- For constructing nested pipelines. For example, an outer pipeline that communicates between chips, and an inner pipeline that performs HBM-VMEM pipelining.\n",
    "- For using `emit_pipeline` specific features such as lookahead prefetch and dynamic block shapes (covered below).\n",
    "\n",
    "`pltpu.emit_pipeline` follows a similar signature to `pl.pallas_call` and requires you to specify a body `kernel`, a grid, and block specs for inputs and outputs:\n",
    "\n",
    "```python\n",
    "def emit_pipeline(\n",
    "    kernel: Callable,\n",
    "    grid: tuple[int],\n",
    "    in_specs: PyTree[BlockSpec] = None,\n",
    "    out_specs: PyTree[BlockSpec] = None,\n",
    "    dimension_semantics: tuple[GridDimensionSemantics] = None,\n",
    "    core_axis: int | None = None,\n",
    ") -> Callable:\n",
    "  ... # Returns a custom pipeline given an inner kernel and BlockSpecs.\n",
    "```\n",
    "\n",
    "The `dimension_semantics` and `core_axis` arguments are used for partitioning the kernel grid over Megacore (see below)."
   ]
  },
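  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the nested structure (the doubling kernel and shapes are arbitrary, and running it requires a TPU), the outer `pallas_call` keeps its operands in HBM via `MemorySpace.ANY`, while the inner `emit_pipeline` streams blocks through VMEM:\n",
    "\n",
    "```python\n",
    "def inner_body(x_vmem_ref, o_vmem_ref):\n",
    "  o_vmem_ref[...] = x_vmem_ref[...] * 2\n",
    "\n",
    "def outer_kernel(x_hbm_ref, o_hbm_ref):\n",
    "  # Inner pipeline: streams (8, 128) blocks of the HBM operands through VMEM.\n",
    "  block_spec = pl.BlockSpec((8, 128), lambda i: (i, 0))\n",
    "  pltpu.emit_pipeline(\n",
    "      inner_body,\n",
    "      grid=(4,),\n",
    "      in_specs=[block_spec],\n",
    "      out_specs=block_spec,\n",
    "  )(x_hbm_ref, o_hbm_ref)\n",
    "\n",
    "def double(x):\n",
    "  hbm_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)\n",
    "  return pl.pallas_call(\n",
    "      outer_kernel,\n",
    "      in_specs=[hbm_spec],\n",
    "      out_specs=hbm_spec,\n",
    "      out_shape=jax.ShapeDtypeStruct((32, 128), jnp.float32),\n",
    "  )(x)\n",
    "```"
   ]
  },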
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Lookahead Prefetch\n",
    "\n",
    "Lookahead prefetch is a pipelining feature where the pipeline will attempt to prefetch the next input block as soon as a buffering slot is available, rather than on the iteration directly before the block would be used. For example, if the kernel had a grid of `(8,)` and the block indices to fetch on each iteration were `0, 0, 0, 0, 1, 1, 1, 1`, then lookahead prefetch will begin fetching both blocks `0` and `1` on iteration 0, whereas the standard pipeline schedule would fetch block `0` on iteration 0 but not begin fetching block `1` until iteration 3. There is a small amount of control-flow overhead in performing lookahead, so it is disabled by default.\n",
    "\n",
    "Lookahead is primarily useful when there is a variable amount of compute work in each block, such as when some blocks are skipped or contain a reduced amount of work. In these cases, there may not be enough compute work in the iteration immediately preceding the step when the block is needed to fully overlap with the memory transfer. Therefore, we would like to begin fetching blocks earlier in the pipeline.\n",
    "\n",
    "Lookahead prefetch can be used in conjunction with multiple buffering and can likewise be enabled by passing `pl.Buffered` into the `pipeline_mode` argument:\n",
    "```python\n",
    "pl.BlockSpec(\n",
    "  pipeline_mode=pl.Buffered(buffer_count=buffer_count, use_lookahead=True)\n",
    ")\n",
    "```"
   ]
  },
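  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build intuition for the two schedules, here is a toy pure-Python model of when each distinct block's fetch can begin. This illustrates the scheduling idea only; it is not the actual Pallas implementation:\n",
    "\n",
    "```python\n",
    "def fetch_schedule(block_ids, buffer_count, use_lookahead):\n",
    "  # Toy model: maps each distinct block id to the iteration at which\n",
    "  # its fetch can begin.\n",
    "  first_use, last_use = {}, {}\n",
    "  for it, b in enumerate(block_ids):\n",
    "    first_use.setdefault(b, it)\n",
    "    last_use[b] = it\n",
    "  order = sorted(first_use, key=first_use.get)\n",
    "  starts = {}\n",
    "  for rank, b in enumerate(order):\n",
    "    if use_lookahead:\n",
    "      if rank < buffer_count:\n",
    "        starts[b] = 0  # A buffer slot is free from the start.\n",
    "      else:\n",
    "        # Wait for the slot held by the block `buffer_count` places\n",
    "        # earlier to be released after its last use.\n",
    "        starts[b] = last_use[order[rank - buffer_count]] + 1\n",
    "    else:\n",
    "      # Standard schedule: fetch on the iteration directly before first use.\n",
    "      starts[b] = max(first_use[b] - 1, 0)\n",
    "  return starts\n",
    "\n",
    "# Reproduces the example above with a grid of (8,) and double buffering:\n",
    "fetch_schedule([0, 0, 0, 0, 1, 1, 1, 1], 2, False)  # {0: 0, 1: 3}\n",
    "fetch_schedule([0, 0, 0, 0, 1, 1, 1, 1], 2, True)   # {0: 0, 1: 0}\n",
    "```"
   ]
  },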
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dynamic Block Shapes\n",
    "\n",
    "`pltpu.emit_pipeline` supports pipelining over blocks with dynamic but bounded shapes. In order to specify such a block shape, the dynamic-sized dimension in the block should be marked with `pl.BoundedSlice(max_size)` rather than a static integer size, where `max_size` is the maximum size of the block. In addition, the corresponding index returned by `index_map` should be a dynamic slice constructed via `pl.ds(start, size)` where both `start` and `size` are _element_ indices (not block indices) and can be dynamic.\n",
    "\n",
    "The following is an example of a block spec with a dynamic first dimension (here `start` and `size` stand for values computed from the grid indices):\n",
    "\n",
    "```python\n",
    "pl.BlockSpec(\n",
    "   block_shape=(pl.BoundedSlice(32), 256),\n",
    "   index_map=lambda *grid_idxs: (pl.ds(start, size), 0),\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The following kernel copies `x` to the output in dynamic-sized chunks\n",
    "# passed in via `slices`.\n",
    "\n",
    "def dynamic_block_example_kernel(x_hbm, slices_hbm, o_hbm, slices_smem):\n",
    "    pltpu.sync_copy(slices_hbm, slices_smem)  # Copy slices into SMEM.\n",
    "    def pipeline_body(x_vmem, o_vmem):\n",
    "        o_vmem[...] = x_vmem[...]\n",
    "    def index_map(i):\n",
    "        start = slices_smem[i, 0]\n",
    "        size = slices_smem[i, 1] - slices_smem[i, 0]\n",
    "        return (pl.ds(start, size), 0)\n",
    "    block_spec = pl.BlockSpec(block_shape=(pl.BoundedSlice(8), 128),\n",
    "                              index_map=index_map)\n",
    "    pltpu.emit_pipeline(\n",
    "        pipeline_body,\n",
    "        grid=(slices_smem.shape[0],),\n",
    "        in_specs=[block_spec],\n",
    "        out_specs=block_spec\n",
    "    )(x_hbm, o_hbm)\n",
    "\n",
    "x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)\n",
    "slices = jnp.array([[0, 2], [2, 3], [3, 5], [5, 8]], dtype=jnp.int32)\n",
    "\n",
    "hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)\n",
    "out = pl.pallas_call(dynamic_block_example_kernel,\n",
    "               in_specs=[hbm_block_spec, hbm_block_spec],\n",
    "               out_specs=hbm_block_spec,\n",
    "               out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),\n",
    "               scratch_shapes=(pltpu.MemorySpace.SMEM(slices.shape, jnp.int32),)\n",
    "              )(x, slices)\n",
    "\n",
    "np.testing.assert_allclose(x, out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KvPFez9N8cKJ"
   },
   "source": [
    "(pallas_tpu_megacore)=\n",
    "\n",
    "### TPUs in Megacore configuration"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0f4HAVzQ8n71"
   },
   "source": [
    "Some TPU chips have two TensorCores but appear as one device to JAX users.\n",
    "This is called \"megacore\".\n",
    "The separate TensorCores have their own separate VMEM, VREGs, SMEM, SREGs\n",
    "and compute units but *share HBM*.\n",
    "\n",
    "![TPU Memory Space Cartoon (Megacore).png]()\n",
    "\n",
    "Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have\n",
    "only two threads.\n",
    "How do we modify our kernels to utilize both TensorCores simultaneously?\n",
    "\n",
    "The basic idea is that if we have embarrassingly parallel dimensions in our\n",
    "computation, we can split up those dimensions across the TensorCores.\n",
    "We can indicate which dimensions are parallelizable by providing an\n",
    "annotation to `pallas_call` called `dimension_semantics`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "executionInfo": {
     "elapsed": 106,
     "status": "ok",
     "timestamp": 1744910274556,
     "user": {
      "displayName": "Justin Fu",
      "userId": "17543197034567316452"
     },
     "user_tz": 420
    },
    "id": "nQNa8RaQ-TR1",
    "outputId": "29c0b574-3528-49a5-8a88-b6987efc69ce"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[2., 2., 2., ..., 2., 2., 2.],\n",
       "       [2., 2., 2., ..., 2., 2., 2.],\n",
       "       [2., 2., 2., ..., 2., 2., 2.],\n",
       "       ...,\n",
       "       [2., 2., 2., ..., 2., 2., 2.],\n",
       "       [2., 2., 2., ..., 2., 2., 2.],\n",
       "       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):\n",
    "  # Load x and y from VMEM into VREGs\n",
    "  x_vregs = x_vmem_ref[:, :]\n",
    "  y_vregs = y_vmem_ref[:, :]\n",
    "  # Execute a vectorized add\n",
    "  z_vregs = x_vregs + y_vregs\n",
    "  # Store the output values in VREGs back into VMEM\n",
    "  z_vmem_ref[:, :] = z_vregs\n",
    "\n",
    "def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:\n",
    "  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))\n",
    "  return pl.pallas_call(\n",
    "      add_matrices_kernel,\n",
    "      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
    "      in_specs=[block_spec, block_spec],\n",
    "      out_specs=block_spec,\n",
    "      grid=(2,),\n",
    "      compiler_params=pltpu.CompilerParams(\n",
    "          dimension_semantics=(\"parallel\",))\n",
    "  )(x, y)\n",
    "\n",
    "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n",
    "add_matrices_pipelined_megacore(x, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xG51AiUC-8cl"
   },
   "source": [
    "`dimension_semantics` should be a tuple of the same length as `grid` where each\n",
    "entry is either `\"parallel\"` or `\"arbitrary\"`. `\"parallel\"` indicates to Pallas that the iterations of the for loop corresponding to that dimension can be executed independently without affecting the correctness of the program. `\"arbitrary\"` indicates to Pallas that there can be no assumptions made about this grid dimension and it therefore cannot be parallelized.\n",
    "\n",
    "By specifying `dimension_semantics`, we now execute the kernel\n",
    "simultaneously on each TensorCore. Pallas will handle splitting up the grid\n",
    "automatically.\n",
    "\n",
    "> Note that Megacore is currently only available on TPU `v4` and TPU `v5p`. Supplying `dimension_semantics` annotations is a no-op on other platforms, but *not* specifying it will result in only one TensorCore being used (even if more than one is available).\n",
    "\n",
    "When using `pltpu.emit_pipeline`, the `core_axis` argument should be set to the index of a parallel grid axis to partition the grid on. For example, the following template can be used to partition the kernel over a leading parallel grid dimension:\n",
    "\n",
    "```python\n",
    "def kernel_body(...):\n",
    "  def inner_pipeline_body(...):\n",
    "    ...\n",
    "  pltpu.emit_pipeline(inner_pipeline_body,\n",
    "                      grid=(4, 4),\n",
    "                      core_axis=0,\n",
    "                      dimension_semantics=(\"parallel\", \"sequential\"))(...)\n",
    "\n",
    "pl.pallas_call(\n",
    "      kernel_body,\n",
    "      grid=(num_cores,),\n",
    "      compiler_params=pltpu.CompilerParams(\n",
    "          dimension_semantics=(\"parallel\",))\n",
    "  )\n",
    "```"
   ]
  }
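  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what partitioning a parallel axis across cores amounts to (a toy model; the actual split is an implementation detail of Pallas), each core can be assigned a contiguous range of the axis:\n",
    "\n",
    "```python\n",
    "def partition_axis(axis_size, num_cores, core_index):\n",
    "  # Toy model: evenly split a parallel grid axis across cores, with\n",
    "  # earlier cores absorbing any remainder. Returns (start, length).\n",
    "  base, rem = divmod(axis_size, num_cores)\n",
    "  length = base + (1 if core_index < rem else 0)\n",
    "  start = core_index * base + min(core_index, rem)\n",
    "  return start, length\n",
    "\n",
    "# A parallel axis of size 4 split over 2 cores:\n",
    "partition_axis(4, 2, 0)  # (0, 2)\n",
    "partition_axis(4, 2, 1)  # (2, 2)\n",
    "```"
   ]
  }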
 ],
 "metadata": {
  "colab": {
   "last_runtime": {
    "build_target": "//experimental/users/justinfu/pallas:colab",
    "kind": "private"
   },
   "provenance": []
  },
  "jupytext": {
   "formats": "ipynb,md:myst"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
